/*
 * Copyright 2018 The Trickster Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package tls provides helpers for conducting tests that use TLS.
package tls

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"net"
	"os"
	"strings"
	"time"

	"github.com/trickstercache/trickster/v2/pkg/checksum/md5"
)

// WriteTestKeyAndCert generates a self-signed test key and certificate via
// GetTestKeyAndCert and writes them, PEM-encoded with 0600 permissions, to
// keyPath and certPath. Prefer paths under t.TempDir() so the files are removed
// automatically; otherwise the caller is responsible for deleting them.
//
// An empty certPath skips writing the certificate. An empty keyPath skips
// writing the key only when isCA is true; for a non-CA certificate the key is
// always written, so an empty keyPath results in a write error.
func WriteTestKeyAndCert(isCA bool, keyPath, certPath string) error {
	keyPEM, certPEM, err := GetTestKeyAndCert(isCA)
	if err != nil {
		return err
	}

	// Only a CA may omit its key file.
	skipKey := isCA && keyPath == ""
	if !skipKey {
		if err = os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
			return err
		}
	}

	if certPath == "" {
		return nil
	}
	return os.WriteFile(certPath, certPEM, 0o600)
}

// GetTestKeyAndCert generates a fresh 2048-bit RSA key and a self-signed
// certificate for use in tests. The certificate is valid for 5 minutes from the
// time of the call, is issued for server authentication, and covers the IP
// 127.0.0.1 and the DNS name "localhost". When isCA is true the certificate is
// additionally marked as a CA and granted the cert-signing key usage.
//
// It returns the PEM-encoded PKCS#8 private key ("PRIVATE KEY" block) and the
// PEM-encoded certificate ("CERTIFICATE" block), in that order. If generating
// or encoding any of the material fails, the error is returned and both byte
// slices are nil.
func GetTestKeyAndCert(isCA bool) ([]byte, []byte, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	// Random serial number in [0, 2^128).
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, nil, err
	}

	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"Trickster Test Certificate DO NOT USE"},
		},
		NotBefore:             notBefore,
		NotAfter:              notBefore.Add(time.Minute * 5),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		DNSNames:              []string{"localhost"},
	}
	if isCA {
		template.IsCA = true
		template.KeyUsage |= x509.KeyUsageCertSign
	}

	// Self-signed: the template acts as both subject and issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, nil, err
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, nil, err
	}

	var keyBuf, certBuf bytes.Buffer
	if err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		return nil, nil, err
	}
	if err = pem.Encode(&keyBuf, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		return nil, nil, err
	}
	return keyBuf.Bytes(), certBuf.Bytes(), nil
}

// GetTestKeyAndCertFiles writes a freshly generated test key and certificate
// (see GetTestKeyAndCert) to files in the current working directory and
// returns the key path, the certificate path, and a cleanup function that
// removes both files.
//
// condition selects the variant that is written:
//   - a value starting with "ca" generates a CA certificate
//   - "invalid-key" replaces the key file contents with invalid data
//   - "invalid-cert" replaces the certificate file contents with invalid data
//   - anything else writes a valid non-CA key and certificate
//
// The returned cleanup function is never nil, so it is always safe to call
// or defer, even when an error is returned.
func GetTestKeyAndCertFiles(condition string) (string, string, func(), error) {
	noop := func() {}

	k, c, err := GetTestKeyAndCert(strings.HasPrefix(condition, "ca"))
	if err != nil {
		return "", "", noop, err
	}

	// Derive the file names from the randomly generated material so that
	// separate calls are unlikely to reuse the same paths.
	hash := md5.Checksum("trickster " + string(k) + string(c))[0:6]

	kf := "./test." + hash + ".key.pem"
	cf := "./test." + hash + ".cert.pem"

	switch condition {
	case "invalid-key":
		k = []byte("invalid key data\n")
	case "invalid-cert":
		c = []byte("invalid cert data\n")
	}

	if err = os.WriteFile(kf, k, 0o600); err != nil {
		return "", "", noop, err
	}
	if err = os.WriteFile(cf, c, 0o600); err != nil {
		// Don't leave an orphaned key file behind; nothing remains for the
		// caller to clean up afterwards. Removal is best-effort.
		_ = os.Remove(kf)
		return "", "", noop, err
	}

	return kf, cf, func() {
		// Best-effort removal; errors are intentionally ignored.
		_ = os.Remove(kf)
		_ = os.Remove(cf)
	}, nil
}
